#  Copyright (c) ZenML GmbH 2022. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at:
#
#       https://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
#  or implied. See the License for the specific language governing
#  permissions and limitations under the License.
"""Implementation of the SageMaker orchestrator."""

import json
import os
import re
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    List,
    Optional,
    Tuple,
    Type,
    Union,
    cast,
)
from uuid import UUID

import boto3
from botocore.exceptions import WaiterError
from sagemaker.estimator import Estimator
from sagemaker.inputs import TrainingInput
from sagemaker.network import NetworkConfig
from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor
from sagemaker.session import Session
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.workflow.execution_variables import (
    ExecutionVariable,
    ExecutionVariables,
)
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.step_collections import (
    StepCollection,
)
from sagemaker.workflow.steps import (
    ProcessingStep,
    Step,
    TrainingStep,
)
from sagemaker.workflow.triggers import PipelineSchedule

from zenml.client import Client
from zenml.config.base_settings import BaseSettings
from zenml.config.step_run_info import StepRunInfo
from zenml.constants import (
    METADATA_ORCHESTRATOR_LOGS_URL,
    METADATA_ORCHESTRATOR_RUN_ID,
    METADATA_ORCHESTRATOR_URL,
)
from zenml.enums import (
    ExecutionStatus,
    StackComponentType,
)
from zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor import (
    SagemakerOrchestratorConfig,
    SagemakerOrchestratorSettings,
)
from zenml.integrations.aws.orchestrators.sagemaker_orchestrator_entrypoint_config import (
    SAGEMAKER_PROCESSOR_STEP_ENV_VAR_SIZE_LIMIT,
    SagemakerDynamicPipelineEntrypointConfiguration,
    SagemakerEntrypointConfiguration,
)
from zenml.integrations.aws.step_operators.sagemaker_step_operator_entrypoint_config import (
    SagemakerStepOperatorEntrypointConfiguration,
)
from zenml.logger import get_logger
from zenml.metadata.metadata_types import MetadataType, Uri
from zenml.orchestrators import ContainerizedOrchestrator, SubmissionResult
from zenml.orchestrators.utils import get_orchestrator_run_name
from zenml.stack import StackValidator
from zenml.utils.env_utils import split_environment_variables
from zenml.utils.time_utils import to_utc_timezone, utc_now_tz_aware

if TYPE_CHECKING:
    from zenml.models import PipelineRunResponse, PipelineSnapshotResponse
    from zenml.stack import Stack

ENV_ZENML_SAGEMAKER_RUN_ID = "ZENML_SAGEMAKER_RUN_ID"
MAX_POLLING_ATTEMPTS = 100
POLLING_DELAY = 30

logger = get_logger(__name__)


def dissect_schedule_arn(
    schedule_arn: str,
) -> Tuple[Optional[str], Optional[str]]:
    """Break an EventBridge schedule ARN into region and schedule name.

    Args:
        schedule_arn: The ARN of the EventBridge schedule.

    Returns:
        Region Name, Schedule Name (including the group name)

    Raises:
        ValueError: If the input is not a properly formatted ARN.
    """
    parts = schedule_arn.split(":")

    # A valid schedule ARN has at least six colon-separated segments and a
    # resource segment that starts with "schedule/".
    if len(parts) < 6 or not parts[5].startswith("schedule/"):
        raise ValueError("Invalid EventBridge schedule ARN format.")

    # Segment 3 carries the region; everything after "schedule/" in the
    # resource segment is the "<group>/<name>" schedule identifier.
    region = parts[3]
    name = parts[5].split("schedule/")[1]

    return region, name


def construct_pipeline_execution_arn(
    session: Session,
    job_name: str,
    is_training_job: bool = True,
) -> str:
    """Build the ARN of a SageMaker training or processing job.

    Args:
        session: The SageMaker session.
        job_name: The name of the job.
        is_training_job: Whether the job is a training job or a processing job.

    Returns:
        The pipeline execution ARN.
    """
    boto_session = session.boto_session
    # Resolve the account id via STS from the session's credentials.
    identity = boto_session.client("sts").get_caller_identity()
    account_id = identity["Account"]
    region_name = boto_session.region_name

    resource = "training-job" if is_training_job else "processing-job"
    return f"arn:aws:sagemaker:{region_name}:{account_id}:{resource}/{job_name}"


def dissect_pipeline_execution_arn(
    pipeline_execution_arn: str,
) -> Tuple[
    Optional[str], Optional[str], Optional[str], Optional[str], Optional[str]
]:
    """Extract region name, pipeline name, type and execution id from the ARN.

    Args:
        pipeline_execution_arn: the pipeline or training/processing job
            execution ARN

    Returns:
        Region Name, Account ID, Pipeline Type, Pipeline Name, Execution ID in order
    """
    # Pipeline executions carry an explicit execution id in the ARN.
    pipeline_match = re.search(
        r"sagemaker:(.*?):(.*?):pipeline/(.*?)/execution/(.*)",
        pipeline_execution_arn,
    )
    if pipeline_match:
        region, account, name, execution_id = pipeline_match.groups()
        return region, account, "pipeline", name, execution_id

    # Standalone training/processing jobs have no execution id.
    for pattern, job_type in (
        (r"sagemaker:(.*?):(.*?):training-job/(.*)", "training"),
        (r"sagemaker:(.*?):(.*?):processing-job/(.*)", "processing"),
    ):
        job_match = re.search(pattern, pipeline_execution_arn)
        if job_match:
            region, account, name = job_match.groups()
            return region, account, job_type, name, None

    logger.warning(
        f"Unable to extract information from the execution ARN: "
        f"{pipeline_execution_arn}"
    )

    return None, None, None, None, None


class SagemakerOrchestrator(ContainerizedOrchestrator):
    """Orchestrator responsible for running pipelines or training/processing jobs on Sagemaker."""

    _sagemaker_session: Optional[Session] = None

    @property
    def config(self) -> SagemakerOrchestratorConfig:
        """The orchestrator configuration, narrowed to its concrete type.

        Returns:
            The configuration.
        """
        config = cast(SagemakerOrchestratorConfig, self._config)
        return config

    @property
    def validator(self) -> Optional[StackValidator]:
        """Validates the stack.

        In the remote case, checks that the stack contains a container registry,
        image builder and only remote components.

        Returns:
            A `StackValidator` instance.
        """

        def _validate_remote_components(
            stack: "Stack",
        ) -> Tuple[bool, str]:
            # Every component must be remote: local components cannot be
            # reached from within the SageMaker execution environment.
            for component in stack.components.values():
                if component.config.is_local:
                    return False, (
                        f"The Sagemaker orchestrator runs pipelines remotely, "
                        f"but the '{component.name}' {component.type.value} is "
                        "a local stack component and will not be available in "
                        "the Sagemaker step.\nPlease ensure that you always "
                        "use non-local stack components with the Sagemaker "
                        "orchestrator."
                    )

            return True, ""

        return StackValidator(
            required_components={
                StackComponentType.CONTAINER_REGISTRY,
                StackComponentType.IMAGE_BUILDER,
            },
            custom_validation_function=_validate_remote_components,
        )

    def get_orchestrator_run_id(self) -> str:
        """Returns the run id of the active orchestrator run.

        Important: This needs to be a unique ID and return the same value for
        all steps of a pipeline run.

        Returns:
            The orchestrator run id.

        Raises:
            RuntimeError: If the run id cannot be read from the environment.
        """
        # Pipeline executions set ENV_ZENML_SAGEMAKER_RUN_ID explicitly,
        # while standalone training jobs expose TRAINING_JOB_ARN instead.
        for env_var in (ENV_ZENML_SAGEMAKER_RUN_ID, "TRAINING_JOB_ARN"):
            value = os.environ.get(env_var)
            if value is not None:
                return value

        # Standalone processing jobs don't expose their ARN via an
        # environment variable, but SageMaker mounts a job config file
        # inside the container that contains it.
        config_file_path = "/opt/ml/config/processingjobconfig.json"

        if os.path.exists(config_file_path):
            try:
                with open(config_file_path, "r") as config_file:
                    job_config = json.load(config_file)

                job_arn = job_config.get("ProcessingJobArn")
                if isinstance(job_arn, str):
                    return job_arn
            except Exception as e:
                # Best-effort: fall through to the RuntimeError below.
                logger.exception(
                    f"Error reading processing job config file: {e}"
                )

        raise RuntimeError("Unable to read run id from environment.")

    @property
    def settings_class(self) -> Optional[Type["BaseSettings"]]:
        """The settings class used by the Sagemaker orchestrator.

        Returns:
            The settings class.
        """
        settings_cls: Type["BaseSettings"] = SagemakerOrchestratorSettings
        return settings_cls

    @property
    def sagemaker_session(self) -> Session:
        """The (lazily created and cached) SageMaker session.

        Returns:
            The SageMaker session.
        """
        # Drop the cached session when the connector credentials have
        # expired so a freshly authenticated one is created below.
        if self.connector_has_expired():
            self._sagemaker_session = None

        session = self._sagemaker_session
        if session is None:
            session = self._get_sagemaker_session()
            self._sagemaker_session = session

        return session

    def _get_sagemaker_session(self) -> Session:
        """Method to create the sagemaker session with proper authentication.

        Returns:
            The Sagemaker Session.

        Raises:
            RuntimeError: If the connector returns the wrong type for the
                session.
        """
        boto_session: boto3.Session
        connector = self.get_connector()

        if connector:
            # Option 1: a linked service connector hands us an authenticated
            # boto3 session directly.
            boto_session = connector.connect()
            if not isinstance(boto_session, boto3.Session):
                raise RuntimeError(
                    f"Expected to receive a `boto3.Session` object from the "
                    f"linked connector, but got type `{type(boto_session)}`."
                )
        else:
            # Option 2: explicit configuration. Args that are not provided
            # will be taken from the default AWS config.
            boto_session = boto3.Session(
                aws_access_key_id=self.config.aws_access_key_id,
                aws_secret_access_key=self.config.aws_secret_access_key,
                region_name=self.config.region,
                profile_name=self.config.aws_profile,
            )

            # When a role ARN is configured, exchange the base credentials
            # for temporary ones by assuming that role via STS.
            if self.config.aws_auth_role_arn:
                response = boto_session.client("sts").assume_role(
                    RoleArn=self.config.aws_auth_role_arn,
                    RoleSessionName="zenml-sagemaker-orchestrator",
                )
                creds = response["Credentials"]
                boto_session = boto3.Session(
                    aws_access_key_id=creds["AccessKeyId"],
                    aws_secret_access_key=creds["SecretAccessKey"],
                    aws_session_token=creds["SessionToken"],
                    region_name=self.config.region,
                )

        return Session(
            boto_session=boto_session, default_bucket=self.config.bucket
        )

    def _create_job_or_step(
        self,
        session: Session,
        base_job_name: str,
        environment: Dict[str, str],
        image: str,
        settings: SagemakerOrchestratorSettings,
        command: List[str],
        arguments: List[str],
        run_id_variable: Optional[ExecutionVariable] = None,
        create_step: bool = False,
        step_name: Optional[str] = None,
        upstream_steps: Optional[List[str]] = None,
        wait: bool = True,
    ) -> Tuple[
        Union[Estimator, Processor, Step],
        Optional[str],
    ]:
        """Create and optionally run a SageMaker training or processing job or step.

        Depending on `settings.use_training_step` (with the component config
        as fallback), the work is wrapped in either an `Estimator` (training
        job) or a `Processor` (processing job). When `create_step` is True, a
        pipeline `TrainingStep`/`ProcessingStep` is returned instead of
        starting the job immediately.

        Args:
            session: The SageMaker session.
            base_job_name: The suffix used to construct the job name.
            environment: The environment variables to set in the execution
                environment.
            image: The image to use for the job.
            settings: The settings for the job.
            command: The command to use for the job.
            arguments: The arguments to use for the job.
            run_id_variable: The execution variable to use for the job run id.
                If not provided, the variable used for the training or
                processing job name will be used.
            create_step: Whether to create a pipeline step instead of a job.
                Requires `step_name` to be set.
            step_name: The name of the pipeline step.
            upstream_steps: The upstream steps of the pipeline step.
            wait: Whether to wait for the job to complete. Only relevant when
                `create_step` is False.

        Returns:
            The estimator or processor object or pipeline step object and the
            job name. The job name is None when a pipeline step is created.

        Raises:
            TypeError: If the network_config passed is not compatible with the
                AWS SageMaker NetworkConfig class.
        """
        # Sagemaker does not allow environment variables longer than 256
        # characters to be passed to Processor steps. If an environment variable
        # is longer than 256 characters, we split it into multiple environment
        # variables (chunks) and re-construct it on the other side using the
        # custom entrypoint configuration.
        split_environment_variables(
            size_limit=SAGEMAKER_PROCESSOR_STEP_ENV_VAR_SIZE_LIMIT,
            env=environment,
        )

        # Sagemaker requires the base job name to use alphanum and hyphens only
        base_job_name = re.sub(r"[^a-zA-Z0-9\-]", "-", base_job_name)

        # The step-level setting wins over the component-level config; when
        # neither is set, default to a training step/job.
        use_training_step = (
            settings.use_training_step
            if settings.use_training_step is not None
            else (
                self.config.use_training_step
                if self.config.use_training_step is not None
                else True
            )
        )

        # Seed the job arguments from the user-provided estimator/processor
        # args. Note that the two SageMaker APIs use different keyword names
        # for the same concepts (volume size and max runtime).
        if use_training_step:
            job_args = settings.estimator_args.copy() or {}
            job_args.setdefault("volume_size", settings.volume_size_in_gb)
            job_args.setdefault("max_run", settings.max_runtime_in_seconds)
        else:
            job_args = settings.processor_args.copy() or {}
            job_args.setdefault(
                "volume_size_in_gb", settings.volume_size_in_gb
            )
            job_args.setdefault(
                "max_runtime_in_seconds",
                settings.max_runtime_in_seconds,
            )

        job_args.setdefault(
            "role",
            settings.execution_role or self.config.execution_role,
        )

        # Convert the tag mapping into the Key/Value list format expected by
        # the SageMaker API; user-provided "tags" in job_args take precedence.
        tags = settings.tags
        job_args.setdefault(
            "tags",
            (
                [{"Key": key, "Value": value} for key, value in tags.items()]
                if tags
                else None
            ),
        )

        job_args.setdefault(
            "instance_type", settings.instance_type or "ml.m5.large"
        )

        # These values are always controlled by the orchestrator and
        # deliberately override any user-provided job args.
        job_args["image_uri"] = image
        job_args["instance_count"] = 1
        job_args["sagemaker_session"] = session

        # Convert network_config to sagemaker.network.NetworkConfig if
        # present
        network_config = job_args.get("network_config")

        if network_config and isinstance(network_config, dict):
            try:
                job_args["network_config"] = NetworkConfig(**network_config)
            except TypeError:
                # If the network_config passed is not compatible with the
                # NetworkConfig class, raise a more informative error.
                raise TypeError(
                    "Expected a sagemaker.network.NetworkConfig "
                    "compatible object for the network_config argument, "
                    "but the network_config processor argument is invalid."
                    "See https://sagemaker.readthedocs.io/en/stable/api/utility/network.html#sagemaker.network.NetworkConfig "
                    "for more information about the NetworkConfig class."
                )

        # Construct S3 inputs to container for step
        training_inputs: Optional[
            Union[TrainingInput, Dict[str, TrainingInput]]
        ] = None
        processing_inputs: Optional[List[ProcessingInput]] = None

        if settings.input_data_s3_uri is None:
            pass
        elif isinstance(settings.input_data_s3_uri, str):
            # A single URI maps to the default input channel/location.
            if use_training_step:
                training_inputs = TrainingInput(
                    s3_data=settings.input_data_s3_uri,
                    input_mode=settings.input_data_s3_mode,
                )
            else:
                processing_inputs = [
                    ProcessingInput(
                        source=settings.input_data_s3_uri,
                        destination="/opt/ml/processing/input/data",
                        s3_input_mode=settings.input_data_s3_mode,
                    )
                ]
        elif isinstance(settings.input_data_s3_uri, dict):
            # A dict maps channel names to URIs; each channel gets its own
            # input (and, for processing, its own destination subdirectory).
            if use_training_step:
                training_inputs = {}
                for (
                    channel,
                    s3_uri,
                ) in settings.input_data_s3_uri.items():
                    training_inputs[channel] = TrainingInput(
                        s3_data=s3_uri,
                        input_mode=settings.input_data_s3_mode,
                    )
            else:
                processing_inputs = []
                for (
                    channel,
                    s3_uri,
                ) in settings.input_data_s3_uri.items():
                    processing_inputs.append(
                        ProcessingInput(
                            source=s3_uri,
                            destination=f"/opt/ml/processing/input/data/{channel}",
                            s3_input_mode=settings.input_data_s3_mode,
                        )
                    )

        # Construct S3 outputs from container for step
        outputs = None
        output_path = None

        if settings.output_data_s3_uri is None:
            pass
        elif isinstance(settings.output_data_s3_uri, str):
            if use_training_step:
                output_path = settings.output_data_s3_uri
            else:
                outputs = [
                    ProcessingOutput(
                        source="/opt/ml/processing/output/data",
                        destination=settings.output_data_s3_uri,
                        s3_upload_mode=settings.output_data_s3_mode,
                    )
                ]
        elif isinstance(settings.output_data_s3_uri, dict):
            # NOTE(review): unlike the string case above, this branch builds
            # ProcessingOutputs regardless of `use_training_step`, and they
            # are only consumed on the processor path below — a dict output
            # combined with a training step is silently ignored. Confirm
            # this is intended.
            outputs = []
            for (
                channel,
                s3_uri,
            ) in settings.output_data_s3_uri.items():
                outputs.append(
                    ProcessingOutput(
                        source=f"/opt/ml/processing/output/data/{channel}",
                        destination=s3_uri,
                        s3_upload_mode=settings.output_data_s3_mode,
                    )
                )

        # Coerce every plain environment value to a string; pipeline
        # variables (like the run id below) are added afterwards as-is.
        final_environment: Dict[str, Union[str, PipelineVariable]] = {
            key: str(value) for key, value in environment.items()
        }

        # Inject the execution variable that `get_orchestrator_run_id` reads
        # back from the environment inside the container.
        if run_id_variable:
            final_environment[ENV_ZENML_SAGEMAKER_RUN_ID] = run_id_variable

        if use_training_step:
            estimator = Estimator(
                base_job_name=base_job_name,
                keep_alive_period_in_seconds=settings.keep_alive_period_in_seconds,
                output_path=output_path,
                environment=final_environment,
                container_entry_point=command,
                container_arguments=arguments,
                **job_args,
            )

            if create_step:
                assert step_name is not None, "Step name is required"
                return TrainingStep(
                    name=step_name,
                    depends_on=cast(
                        Optional[List[Union[str, Step, StepCollection]]],
                        upstream_steps,
                    ),
                    inputs=training_inputs,
                    estimator=estimator,
                ), None

            # Standalone job: launch (and optionally wait for) the training
            # job right away.
            estimator.fit(
                wait=wait,
                inputs=training_inputs,
            )

            assert estimator.latest_training_job is not None, (
                "No training job found"
            )
            job_name = estimator.latest_training_job.job_name

            return estimator, job_name

        else:
            processor = Processor(
                base_job_name=base_job_name,
                entrypoint=cast(
                    Optional[List[Union[str, PipelineVariable]]],
                    command + arguments,
                ),
                env=final_environment,
                **job_args,
            )

            if create_step:
                assert step_name is not None, "Step name is required"
                return ProcessingStep(
                    name=step_name,
                    processor=processor,
                    depends_on=cast(
                        Optional[List[Union[str, Step, StepCollection]]],
                        upstream_steps,
                    ),
                    inputs=processing_inputs,
                    outputs=outputs,
                ), None

            # Standalone job: launch (and optionally wait for) the processing
            # job right away; logs are only streamed when waiting.
            processor.run(
                wait=wait,
                logs=wait,
                inputs=processing_inputs,
                outputs=outputs,
            )

            assert processor.latest_job is not None, "No processing job found"
            job_name = processor.latest_job.job_name

            return processor, job_name

    def submit_pipeline(
        self,
        snapshot: "PipelineSnapshotResponse",
        stack: "Stack",
        base_environment: Dict[str, str],
        step_environments: Dict[str, Dict[str, str]],
        placeholder_run: Optional["PipelineRunResponse"] = None,
    ) -> Optional[SubmissionResult]:
        """Submits a pipeline to the orchestrator.

        This method should only submit the pipeline and not wait for it to
        complete. If the orchestrator is configured to wait for the pipeline run
        to complete, a function that waits for the pipeline run to complete can
        be passed as part of the submission result.

        Args:
            snapshot: The pipeline snapshot to submit.
            stack: The stack the pipeline will run on.
            base_environment: Base environment shared by all steps. This should
                be set if your orchestrator for example runs one container that
                is responsible for starting all the steps.
            step_environments: Environment variables to set when executing
                specific steps.
            placeholder_run: An optional placeholder run for the snapshot.

        Raises:
            RuntimeError: If there is an error creating or scheduling the
                pipeline.
            ValueError: If the schedule is not valid.

        Returns:
            Optional submission result.
        """
        # sagemaker requires pipelineName to use alphanum and hyphens only
        unsanitized_orchestrator_run_name = get_orchestrator_run_name(
            pipeline_name=snapshot.pipeline_configuration.name
        )
        # replace all non-alphanum and non-hyphens with hyphens
        orchestrator_run_name = re.sub(
            r"[^a-zA-Z0-9\-]", "-", unsanitized_orchestrator_run_name
        )

        session = self.sagemaker_session

        sagemaker_steps = []
        for step_name, step in snapshot.step_configurations.items():
            step_environment = step_environments[step_name]

            image = self.get_image(snapshot=snapshot, step_name=step_name)
            command = SagemakerEntrypointConfiguration.get_entrypoint_command()
            arguments = (
                SagemakerEntrypointConfiguration.get_entrypoint_arguments(
                    step_name=step_name, snapshot_id=snapshot.id
                )
            )

            step_settings = cast(
                SagemakerOrchestratorSettings, self.get_settings(step)
            )

            if step_settings.environment:
                step_environment.update(step_settings.environment)

            sagemaker_step, _ = self._create_job_or_step(
                session=session,
                base_job_name=f"{snapshot.pipeline_configuration.name}-{step_name}",
                environment=step_environment,
                image=image,
                settings=step_settings,
                command=command,
                arguments=arguments,
                run_id_variable=ExecutionVariables.PIPELINE_EXECUTION_ARN,
                create_step=True,
                step_name=step_name,
                upstream_steps=step.spec.upstream_steps,
                wait=False,
            )

            sagemaker_steps.append(sagemaker_step)

        # Create the pipeline
        pipeline = Pipeline(
            name=orchestrator_run_name,
            steps=sagemaker_steps,
            sagemaker_session=session,
        )

        settings = cast(
            SagemakerOrchestratorSettings, self.get_settings(snapshot)
        )

        pipeline.create(
            role_arn=self.config.execution_role,
            tags=[
                {"Key": key, "Value": value}
                for key, value in settings.pipeline_tags.items()
            ]
            if settings.pipeline_tags
            else None,
        )

        # Handle scheduling if specified
        if snapshot.schedule:
            if settings.synchronous:
                logger.warning(
                    "The 'synchronous' setting is ignored for scheduled "
                    "pipelines since they run independently of the "
                    "deployment process."
                )

            schedule_name = orchestrator_run_name
            next_execution = None
            start_date = (
                to_utc_timezone(snapshot.schedule.start_time)
                if snapshot.schedule.start_time
                else None
            )

            # Create PipelineSchedule based on schedule type
            if snapshot.schedule.cron_expression:
                cron_exp = self._validate_cron_expression(
                    snapshot.schedule.cron_expression
                )
                schedule = PipelineSchedule(
                    name=schedule_name,
                    cron=cron_exp,
                    start_date=start_date,
                    enabled=True,
                )
            elif snapshot.schedule.interval_second:
                # This is necessary because SageMaker's PipelineSchedule rate
                # expressions require minutes as the minimum time unit.
                # Even if a user specifies an interval of less than 60 seconds,
                # it will be rounded up to 1 minute.
                minutes = max(
                    1,
                    int(
                        snapshot.schedule.interval_second.total_seconds() / 60
                    ),
                )
                schedule = PipelineSchedule(
                    name=schedule_name,
                    rate=(minutes, "minutes"),
                    start_date=start_date,
                    enabled=True,
                )
                next_execution = (
                    snapshot.schedule.start_time or utc_now_tz_aware()
                ) + snapshot.schedule.interval_second
            else:
                # One-time schedule
                execution_time = (
                    snapshot.schedule.run_once_start_time
                    or snapshot.schedule.start_time
                )
                if not execution_time:
                    raise ValueError(
                        "A start time must be specified for one-time "
                        "schedule execution"
                    )
                schedule = PipelineSchedule(
                    name=schedule_name,
                    at=to_utc_timezone(execution_time),
                    enabled=True,
                )
                next_execution = execution_time

            # Get the current role ARN if not explicitly configured
            if self.config.scheduler_role is None:
                logger.info(
                    "No scheduler_role configured. Trying to extract it from "
                    "the client side authentication."
                )
                sts = session.boto_session.client("sts")
                try:
                    scheduler_role_arn = sts.get_caller_identity()["Arn"]
                    # If this is a user ARN, try to get the role ARN
                    if ":user/" in scheduler_role_arn:
                        logger.warning(
                            f"Using IAM user credentials "
                            f"({scheduler_role_arn}). For production "
                            "environments, it's recommended to use IAM roles "
                            "instead."
                        )
                    # If this is an assumed role, extract the role ARN
                    elif ":assumed-role/" in scheduler_role_arn:
                        # Convert assumed-role ARN format to role ARN format
                        # From: arn:aws:sts::123456789012:assumed-role/role-name/session-name
                        # To: arn:aws:iam::123456789012:role/role-name
                        scheduler_role_arn = re.sub(
                            r"arn:aws:sts::(\d+):assumed-role/([^/]+)/.*",
                            r"arn:aws:iam::\1:role/\2",
                            scheduler_role_arn,
                        )
                    elif ":role/" not in scheduler_role_arn:
                        raise RuntimeError(
                            f"Unexpected credential type "
                            f"({scheduler_role_arn}). Please use IAM "
                            f"roles for SageMaker pipeline scheduling."
                        )
                    else:
                        raise RuntimeError(
                            "The ARN of the caller identity "
                            f"`{scheduler_role_arn}` does not "
                            "include a user or a proper role."
                        )
                except Exception:
                    raise RuntimeError(
                        "Failed to get current role ARN. This means the "
                        "your client side credentials are not configured "
                        "correctly to schedule SageMaker pipelines. "
                        "For more information, please check:"
                        "https://docs.zenml.io/stacks/stack-components/orchestrators/sagemaker#required-iam-permissions-for-schedules"
                    )
            else:
                scheduler_role_arn = self.config.scheduler_role

            # Attach schedule to pipeline
            triggers = pipeline.put_triggers(
                triggers=[schedule],
                role_arn=scheduler_role_arn,
            )
            logger.info(f"The schedule ARN is: {triggers[0]}")

            schedule_metadata = {}
            try:
                schedule_metadata = self.generate_schedule_metadata(
                    schedule_arn=triggers[0]
                )
            except Exception as e:
                logger.debug(
                    "There was an error generating schedule metadata: %s", e
                )

            logger.info(
                f"Successfully scheduled pipeline with name: {schedule_name}\n"
                + (
                    f"First execution will occur at: "
                    f"{next_execution.strftime('%Y-%m-%d %H:%M:%S UTC')}"
                    if next_execution
                    else f"Using cron expression: "
                    f"{snapshot.schedule.cron_expression}"
                )
                + (
                    f" (and every {minutes} minutes after)"
                    if snapshot.schedule.interval_second
                    else ""
                )
            )
            logger.info(
                "\n\nIn order to cancel the schedule, you can use execute "
                "the following command:\n"
            )
            logger.info(
                f"`aws scheduler delete-schedule --name {schedule_name}`"
            )
            return SubmissionResult(metadata=schedule_metadata)
        else:
            # Execute the pipeline immediately if no schedule is specified
            execution = pipeline.start()
            logger.warning(
                "Steps can take 5-15 minutes to start running "
                "when using the Sagemaker Orchestrator."
            )

            run_metadata = self.compute_metadata(
                execution_arn=execution.arn, settings=settings
            )

            _wait_for_completion = None
            if settings.synchronous:

                def _wait_for_completion() -> None:
                    logger.info(
                        "Executing synchronously. Waiting for pipeline to "
                        "finish... \n"
                        "At this point you can `Ctrl-C` out without cancelling the "
                        "execution."
                    )
                    try:
                        execution.wait(
                            delay=POLLING_DELAY,
                            max_attempts=MAX_POLLING_ATTEMPTS,
                        )
                        logger.info("Pipeline completed successfully.")
                    except WaiterError:
                        raise RuntimeError(
                            "Timed out while waiting for pipeline execution to "
                            "finish. For long-running pipelines we recommend "
                            "configuring your orchestrator for asynchronous "
                            "execution. The following command does this for you: \n"
                            f"`zenml orchestrator update {self.name} "
                            f"--synchronous=False`"
                        )

            return SubmissionResult(
                wait_for_completion=_wait_for_completion,
                metadata=run_metadata,
            )

    def submit_dynamic_pipeline(
        self,
        snapshot: "PipelineSnapshotResponse",
        stack: "Stack",
        environment: Dict[str, str],
        placeholder_run: Optional["PipelineRunResponse"] = None,
    ) -> Optional[SubmissionResult]:
        """Submits a dynamic pipeline to the orchestrator.

        Launches a single SageMaker job (training or processing, depending on
        the settings) that runs the dynamic pipeline entrypoint.

        Args:
            snapshot: The snapshot of the pipeline.
            stack: The stack to use for the pipeline.
            environment: The environment variables to set in the pipeline.
            placeholder_run: The placeholder run for the pipeline.

        Returns:
            Optional submission result.

        Raises:
            NotImplementedError: If the pipeline is scheduled.
        """
        if snapshot.schedule:
            raise NotImplementedError(
                "The AWS SageMaker Orchestrator does not currently support "
                "scheduling for dynamic pipelines."
            )

        sm_session = self.sagemaker_session

        try:
            image = self.get_image(snapshot=snapshot)
        except KeyError:
            # No generic pipeline image exists (which means all steps of a
            # static pipeline have custom builds), so any step image will do:
            # each of them includes the dependencies for the active stack.
            first_invocation = next(iter(snapshot.step_configurations))
            image = self.get_image(
                snapshot=snapshot, step_name=first_invocation
            )

        settings = cast(
            SagemakerOrchestratorSettings, self.get_settings(snapshot)
        )

        entrypoint_command = (
            SagemakerDynamicPipelineEntrypointConfiguration.get_entrypoint_command()
        )
        entrypoint_args = (
            SagemakerDynamicPipelineEntrypointConfiguration.get_entrypoint_arguments(
                snapshot_id=snapshot.id,
                run_id=placeholder_run.id if placeholder_run else None,
            )
        )

        # Launch the job asynchronously; waiting (if requested) happens via
        # the callback returned in the submission result.
        job, job_name = self._create_job_or_step(
            session=sm_session,
            base_job_name=snapshot.pipeline_configuration.name,
            environment=environment,
            image=image,
            settings=settings,
            command=entrypoint_command,
            arguments=entrypoint_args,
            wait=False,
        )

        assert job_name is not None, "Job name is not set"

        is_training_job = isinstance(job, Estimator)

        metadata = self.compute_metadata(
            execution_arn=construct_pipeline_execution_arn(
                session=sm_session,
                job_name=job_name,
                is_training_job=is_training_job,
            ),
            settings=settings,
        )

        _wait_for_completion = None
        if settings.synchronous:

            def _wait_for_completion() -> None:
                logger.info(
                    "Executing synchronously. Waiting for pipeline to "
                    "finish... \n"
                    "At this point you can `Ctrl-C` out without cancelling the "
                    "execution."
                )
                try:
                    # Stream logs and block until the job finishes; the API
                    # differs between training and processing jobs.
                    if is_training_job:
                        sm_session.logs_for_job(
                            job_name,
                            wait=True,
                            poll=POLLING_DELAY,
                            timeout=POLLING_DELAY * MAX_POLLING_ATTEMPTS,
                        )
                    else:
                        sm_session.logs_for_processing_job(
                            job_name,
                            wait=True,
                            poll=POLLING_DELAY,
                        )

                    logger.info("Pipeline completed successfully.")
                except WaiterError:
                    raise RuntimeError(
                        "Timed out while waiting for pipeline execution to "
                        "finish. For long-running pipelines we recommend "
                        "configuring your orchestrator for asynchronous "
                        "execution. The following command does this for you: \n"
                        f"`zenml orchestrator update {self.name} "
                        f"--synchronous=False`"
                    )

        return SubmissionResult(
            wait_for_completion=_wait_for_completion,
            metadata=metadata,
        )

    def run_isolated_step(
        self, step_run_info: "StepRunInfo", environment: Dict[str, str]
    ) -> None:
        """Runs an isolated step on Sagemaker.

        Launches a single SageMaker job that executes the step-operator
        entrypoint for the given step and blocks until it finishes.

        Args:
            step_run_info: The step run information.
            environment: The environment variables to set in the execution
                environment.
        """
        logger.info(
            "Launching job for step `%s`.",
            step_run_info.pipeline_step_name,
        )

        step_name = step_run_info.pipeline_step_name

        settings = cast(
            SagemakerOrchestratorSettings, self.get_settings(step_run_info)
        )
        image = self.get_image(
            snapshot=step_run_info.snapshot,
            step_name=step_name,
        )

        entrypoint_command = (
            SagemakerStepOperatorEntrypointConfiguration.get_entrypoint_command()
        )
        entrypoint_args = (
            SagemakerStepOperatorEntrypointConfiguration.get_entrypoint_arguments(
                step_name=step_name,
                snapshot_id=step_run_info.snapshot.id,
                step_run_id=str(step_run_info.step_run_id),
            )
        )

        # wait=True makes this call block until the job has finished.
        self._create_job_or_step(
            session=self.sagemaker_session,
            base_job_name=f"{step_run_info.pipeline.name}-{step_name}",
            environment=environment,
            image=image,
            settings=settings,
            command=entrypoint_command,
            arguments=entrypoint_args,
            wait=True,
        )

    def get_pipeline_run_metadata(
        self, run_id: UUID
    ) -> Dict[str, "MetadataType"]:
        """Get general component-specific metadata for a pipeline run.

        Args:
            run_id: The ID of the pipeline run.

        Returns:
            A dictionary of metadata.
        """
        # Resolve the settings that were active for this particular run.
        pipeline_run = Client().get_pipeline_run(run_id)
        run_settings = cast(
            SagemakerOrchestratorSettings,
            self.get_settings(pipeline_run),
        )

        return self.compute_metadata(
            execution_arn=self.get_orchestrator_run_id(),
            settings=run_settings,
        )

    def fetch_status(
        self, run: "PipelineRunResponse", include_steps: bool = False
    ) -> Tuple[
        Optional[ExecutionStatus], Optional[Dict[str, ExecutionStatus]]
    ]:
        """Refreshes the status of a specific pipeline run.

        Args:
            run: The run that was executed by this orchestrator.
            include_steps: Whether to fetch steps

        Returns:
            A tuple of (pipeline_status, step_statuses_dict).
            Step statuses are not supported for SageMaker, so step_statuses_dict will always be None.

        Raises:
            AssertionError: If the run was not executed by this orchestrator.
            ValueError: If it fetches an unknown state or if we can not fetch
                the orchestrator run ID.
            ValueError: If the orchestrator run ID cannot be used to identify
                the pipeline type.
        """
        # Make sure that the stack exists and is accessible
        if run.stack is None:
            raise ValueError(
                "The stack that the run was executed on is not available "
                "anymore."
            )

        # Make sure that the run belongs to this orchestrator
        assert (
            self.id
            == run.stack.components[StackComponentType.ORCHESTRATOR][0].id
        )

        # Initialize the Sagemaker client
        session = self.sagemaker_session
        sagemaker_client = session.sagemaker_client

        # Resolve the orchestrator run ID (the execution/job ARN), preferring
        # the value stored in the run metadata.
        if METADATA_ORCHESTRATOR_RUN_ID in run.run_metadata:
            run_id = run.run_metadata[METADATA_ORCHESTRATOR_RUN_ID]
        elif run.orchestrator_run_id is not None:
            run_id = run.orchestrator_run_id
        else:
            raise ValueError(
                "Can not find the orchestrator run ID, thus can not fetch "
                "the status."
            )

        _, _, pipeline_type, pipeline_name, _ = dissect_pipeline_execution_arn(
            str(run_id)
        )

        if pipeline_type == "pipeline":
            status = sagemaker_client.describe_pipeline_execution(
                PipelineExecutionArn=run_id
            )["PipelineExecutionStatus"]

            # Map the potential outputs to ZenML ExecutionStatus. Potential values:
            # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DescribePipelineExecution.html
            status_map = {
                "Executing": ExecutionStatus.RUNNING,
                "Stopping": ExecutionStatus.STOPPING,
                "Stopped": ExecutionStatus.STOPPED,
                "Failed": ExecutionStatus.FAILED,
                "Succeeded": ExecutionStatus.COMPLETED,
            }
            pipeline_status = status_map.get(status)
            if pipeline_status is None:
                raise ValueError(
                    "Unknown status for the pipeline execution: "
                    f"`{status}`."
                )

            # SageMaker doesn't support step-level status fetching yet
            return pipeline_status, None

        elif pipeline_type == "training":
            status = sagemaker_client.describe_training_job(
                TrainingJobName=pipeline_name
            )["TrainingJobStatus"]

            # Map the potential outputs to ZenML ExecutionStatus. Potential values:
            # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DescribeTrainingJob.html#sagemaker-DescribeTrainingJob-response-TrainingJobStatus
            status_map = {
                "InProgress": ExecutionStatus.RUNNING,
                "Stopping": ExecutionStatus.STOPPING,
                "Completed": ExecutionStatus.COMPLETED,
                "Failed": ExecutionStatus.FAILED,
                "Stopped": ExecutionStatus.STOPPED,
            }
            pipeline_status = status_map.get(status)
            if pipeline_status is None:
                raise ValueError(
                    f"Unknown status for the training job: `{status}`."
                )

            return pipeline_status, None

        elif pipeline_type == "processing":
            status = sagemaker_client.describe_processing_job(
                ProcessingJobName=pipeline_name
            )["ProcessingJobStatus"]

            # Map the potential outputs to ZenML ExecutionStatus. Potential values:
            # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DescribeProcessingJob.html#sagemaker-DescribeProcessingJob-response-ProcessingJobStatus
            # NOTE: processing jobs have no "Stopped" status in the API.
            status_map = {
                "InProgress": ExecutionStatus.RUNNING,
                "Stopping": ExecutionStatus.STOPPING,
                "Completed": ExecutionStatus.COMPLETED,
                "Failed": ExecutionStatus.FAILED,
            }
            pipeline_status = status_map.get(status)
            if pipeline_status is None:
                raise ValueError(
                    f"Unknown status for the processing job: `{status}`."
                )

            return pipeline_status, None

        else:
            raise ValueError("Unknown pipeline type.")

    def compute_metadata(
        self,
        execution_arn: str,
        settings: SagemakerOrchestratorSettings,
    ) -> Dict[str, MetadataType]:
        """Generate run metadata based on the generated Sagemaker Execution.

        Args:
            execution_arn: The ARN of the pipeline execution.
            settings: The Sagemaker orchestrator settings.

        Returns:
            A dictionary of metadata related to the pipeline run.
        """
        # Default to the training-job flavor unless the settings say otherwise.
        use_training_job = (
            settings.use_training_step
            if settings.use_training_step is not None
            else True
        )

        # Orchestrator Run ID
        metadata: Dict[str, MetadataType] = {
            "pipeline_execution_arn": execution_arn,
            METADATA_ORCHESTRATOR_RUN_ID: execution_arn,
        }

        # URL to the Sagemaker's pipeline view
        orchestrator_url = self._compute_orchestrator_url(
            execution_arn=execution_arn, use_training_job=use_training_job
        )
        if orchestrator_url:
            metadata[METADATA_ORCHESTRATOR_URL] = Uri(orchestrator_url)

        # URL to the corresponding CloudWatch page
        logs_url = self._compute_orchestrator_logs_url(
            execution_arn=execution_arn, use_training_job=use_training_job
        )
        if logs_url:
            metadata[METADATA_ORCHESTRATOR_LOGS_URL] = Uri(logs_url)

        return metadata

    def _compute_orchestrator_url(
        self,
        execution_arn: Any,
        use_training_job: bool,
    ) -> Optional[str]:
        """Generate the Orchestrator Dashboard URL upon pipeline execution.

        Args:
            execution_arn: The ARN of the pipeline execution.
            use_training_job: Whether the pipeline uses the training job type or
                the processing job type.

        Returns:
            The URL to the dashboard view in SageMaker.
        """
        (
            region_name,
            account_id,
            pipeline_type,
            pipeline_name,
            execution_id,
        ) = dissect_pipeline_execution_arn(execution_arn)

        if pipeline_type is None:
            return None

        # Prefer a SageMaker Studio URL when a Studio domain is available.
        # Processing jobs are not supported in SageMaker Studio, so skip the
        # lookup for those and fall through to the plain console URL.
        if pipeline_type != "processing":
            try:
                sm_client = self.sagemaker_session.sagemaker_client

                # List the Studio domains and get the Studio Domain ID
                studio_domain_id = sm_client.list_domains()["Domains"][0][
                    "DomainId"
                ]

                if pipeline_type == "pipeline":
                    return (
                        f"https://studio-{studio_domain_id}.studio.{region_name}."
                        f"sagemaker.aws/pipelines/view/{pipeline_name}/executions"
                        f"/{execution_id}/graph"
                    )
                else:  # training:
                    return (
                        f"https://studio-{studio_domain_id}.studio.{region_name}."
                        f"sagemaker.aws/jobs/train/{pipeline_name}"
                    )
            except Exception as e:
                logger.warning(
                    f"There was an issue while extracting the SageMaker Studio "
                    f"URL: {e}"
                )

        # Fall back to the regular AWS console: the jobs list for pipeline
        # executions, or the specific job page for standalone jobs.
        if pipeline_type == "pipeline":
            job_kind = "jobs" if use_training_job else "processing-jobs"
            name_suffix = ""
        else:
            job_kind = (
                "jobs" if pipeline_type == "training" else "processing-jobs"
            )
            name_suffix = f"/{pipeline_name}"

        return (
            f"https://{account_id}.{region_name}.console.aws.amazon.com"
            f"/sagemaker/home?region={region_name}#/{job_kind}"
            f"{name_suffix}"
        )

    @staticmethod
    def _compute_orchestrator_logs_url(
        execution_arn: Any,
        use_training_job: bool,
    ) -> Optional[str]:
        """Generate the CloudWatch URL upon pipeline execution.

        Args:
            execution_arn: The ARN of the pipeline execution.
            use_training_job: Whether the pipeline uses the training job type or
                the processing job type.

        Returns:
            the URL querying the pipeline logs in CloudWatch on AWS.
        """
        (
            region_name,
            _,
            pipeline_type,
            pipeline_name,
            execution_id,
        ) = dissect_pipeline_execution_arn(execution_arn)

        # Without a recognized pipeline type there is no log group to link to.
        if pipeline_type is None:
            return None

        job_kind = "Training" if use_training_job else "Processing"

        # Pipeline executions are filtered by execution ID; standalone jobs
        # by their job name.
        stream_filter = (
            f"pipelines-{execution_id}"
            if pipeline_type == "pipeline"
            else f"{pipeline_name}"
        )

        return (
            f"https://{region_name}.console.aws.amazon.com/"
            f"cloudwatch/home?region={region_name}#logsV2:log-groups/log-group"
            f"/$252Faws$252Fsagemaker$252F{job_kind}Jobs$3FlogStreamNameFilter"
            f"$3D{stream_filter}"
        )

    @staticmethod
    def generate_schedule_metadata(
        schedule_arn: str,
    ) -> Dict[str, MetadataType]:
        """Attaches metadata to the ZenML Schedules.

        Args:
            schedule_arn: The trigger ARNs that is generated on the AWS side.

        Returns:
            a dictionary containing metadata related to the schedule.
        """
        schedule_region, schedule_name = dissect_schedule_arn(
            schedule_arn=schedule_arn
        )

        # Deep link to the schedule in the EventBridge Scheduler console.
        trigger_url = (
            f"https://{schedule_region}.console.aws.amazon.com/scheduler/home"
            f"?region={schedule_region}#schedules/{schedule_name}"
        )

        return {"trigger_url": trigger_url}

    @staticmethod
    def _validate_cron_expression(cron_expression: str) -> str:
        """Validates and formats a cron expression for SageMaker schedules.

        Args:
            cron_expression: The cron expression to validate, either raw
                ("15 10 ? * 6L 2022-2023") or wrapped in the AWS
                "cron(...)" syntax.

        Returns:
            The formatted cron expression, with any wrapping "cron(...)"
            removed.

        Raises:
            ValueError: If the cron expression is invalid
        """
        # Strip a wrapping "cron(...)" if it exists. Only the matched
        # prefix/suffix pair is removed: the previous implementation used
        # `.replace("cron(", "").replace(")", "")`, which deleted every
        # ")" (and any embedded "cron(") anywhere in the string.
        cron_exp = cron_expression.strip()
        if cron_exp.startswith("cron(") and cron_exp.endswith(")"):
            cron_exp = cron_exp[len("cron(") : -1]

        # Split into components
        parts = cron_exp.split()
        if len(parts) not in [6, 7]:  # AWS cron requires 6 or 7 fields
            raise ValueError(
                f"Invalid cron expression: {cron_expression}. AWS cron "
                "expressions must have 6 or 7 fields: minute hour day-of-month "
                "month day-of-week year(optional). Example: '15 10 ? * 6L "
                "2022-2023'"
            )

        return cron_exp
